
import json
import torch
import matplotlib.pyplot as plt
import numpy as np
import os

def _column_means(rows):
    """Return the column-wise mean of ``rows`` as a tensor, or ``None`` if ``rows`` is empty/None."""
    if not rows:
        return None
    # rows is expected to be [n_samples, 3] (per the original author's note);
    # averaging over dim 0 leaves one value per column.
    return torch.tensor(rows).mean(dim=0)


def real_hal_sim(ori_path='./att_dict_ori.pth', test_path='./att_dict_test.pth',
                 out_path='./real_hal_sim_final.png', show_labels=True):
    """Plot per-key average attention measures and save the chart as an image.

    Two dictionaries are loaded with ``torch.load``. For every key of the first
    dictionary (in sorted order) the rows stored under that key are averaged
    column-wise. Column 0 and column 1 of the first dictionary become the
    "Real" and "Hal" series; column 2 of the *second* dictionary becomes the
    "Diff" series. The three series are drawn as line plots against the
    position of the key in sorted order.

    Args:
        ori_path: Path of the dictionary providing the "Real" and "Hal" series.
        test_path: Path of the dictionary providing the "Diff" series.
        out_path: Where the rendered figure is written.
        show_labels: When true, annotate every second data point with its value.

    Returns:
        None. The only result is the image written to ``out_path``.

    Raises:
        ValueError: If the first dictionary contains no keys.
    """
    # NOTE: torch.load unpickles its input; only point these paths at trusted files.
    att_dict = torch.load(ori_path)
    att_dict_test = torch.load(test_path)

    avg_att_real = []
    avg_att_hal = []
    avg_att_diff = []

    # Step 1: average the samples of every key, visiting keys in sorted order.
    for key in sorted(att_dict):
        means = _column_means(att_dict[key])
        if means is None:
            # No samples for this key: fall back to zeros for all three series.
            avg_att_real.append(0)
            avg_att_hal.append(0)
            avg_att_diff.append(0)
            continue

        # The test dictionary may lack this key or hold no samples for it;
        # use the same zero fallback instead of crashing.
        try:
            rows_test = att_dict_test[key]
        except (KeyError, IndexError):
            rows_test = None
        means_test = _column_means(rows_test)

        avg_att_real.append(means[0].item())
        avg_att_hal.append(means[1].item())
        # NOTE(review): the third series is read from the *test* dictionary while
        # the first two come from the original one -- confirm this is intended.
        avg_att_diff.append(means_test[2].item() if means_test is not None else 0)

    if not avg_att_real:
        raise ValueError(f"{ori_path!r} contains no keys; nothing to plot")

    # Step 2: draw all three measures in one line chart.
    # The x axis is sized from the data rather than assuming exactly 32 keys.
    n_keys = len(avg_att_real)
    x = np.arange(n_keys)

    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        ax.plot(x, avg_att_real, 'o-', color='blue', linewidth=2, markersize=6, label='Real Att:Img - Txt')
        ax.plot(x, avg_att_hal, 's-', color='green', linewidth=2, markersize=6, label='Hal Att:Img - Txt')
        ax.plot(x, avg_att_diff, 'D-', color='red', linewidth=2, markersize=6, label='Diff: Real - Hal')

        # Title and axis labels.
        ax.set_title('Comparison of Attention Measures by Key Index', fontsize=16)
        ax.set_xlabel(f'Key Index (0-{n_keys - 1})', fontsize=12)
        ax.set_ylabel('Average Value', fontsize=12)

        # Ticks (every second index), grid and legend.
        ax.set_xticks(np.arange(0, n_keys, 2))
        ax.grid(True, linestyle='--', alpha=0.6)
        ax.legend(fontsize=12, loc='best')

        if show_labels:
            # Place value labels slightly above the points. abs() keeps the
            # offset pointing upwards even when every value is negative.
            peak = max(max(avg_att_real), max(avg_att_hal), max(avg_att_diff))
            y_offset = abs(peak) * 0.05
            series = (
                (avg_att_real, 'blue'),
                (avg_att_hal, 'green'),
                (avg_att_diff, 'red'),
            )
            # Only even positions are labelled to avoid clutter.
            for i in range(0, n_keys, 2):
                for values, colour in series:
                    ax.text(i, values[i] + y_offset, f"{values[i]:.4f}",
                            ha='center', fontsize=8, color=colour)

        fig.tight_layout()
        fig.savefig(out_path)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

# Script entry point: load the attention dictionaries and write the plot.
# NOTE: this call sits at module level without a __main__ guard, so it also
# runs whenever this file is imported.
real_hal_sim()